import warnings

warnings.filterwarnings("ignore")

import os, copy
import argparse
import json
import pickle
import numpy as np
import torch
import torchaudio
import torch.nn.functional as F
import torch.multiprocessing as mp
from torch.distributed import init_process_group
from torch.nn.parallel import DistributedDataParallel
from env import AttrDict, build_env
from meldataset import MelDataset, MelDataset_ADV, mel_spectrogram, get_dataset_filelist, get_dataset_filelist_libri_adv

from adversarial.adversarial_utils import get_token2text
from adversarial.base_optim import BaseOptimization
from utils import constant
from utils.spec import spectral_normalize_torch, mel_spectrogram
from utils.functions import load_model_adv
from utils.rir import Readrir, create_speech_rir
from librosa.filters import mel as librosa_mel_fn

# Module-level caches written by _attack_2nd_stage:
#   mel_basis   -- key '<fmax>_<device>' -> mel filterbank tensor built with librosa_mel_fn
#   hann_window -- key '<device>'        -> Hann window tensor used for torch.stft / torch.istft
mel_basis = {}
hann_window = {}

def project_into_constraint(spec, fixed_mag_spec):
    """Rescale a real/imaginary-stacked spectrogram towards a target magnitude.

    ``spec`` holds the real and imaginary parts in its last axis. Every bin is
    multiplied by the positive factor ``sqrt(fixed_mag_spec / |spec|)`` (small
    epsilons guard the square root and the division), so the phase of each bin
    is unchanged.

    NOTE(review): because the factor is a *square root* of the magnitude ratio,
    the resulting magnitude is ``sqrt(|spec| * fixed_mag_spec)``, not
    ``fixed_mag_spec`` -- confirm whether that is intended.
    """
    current_mag = torch.sqrt(torch.sum(spec ** 2, dim=-1) + 1e-9)
    scale = torch.sqrt(fixed_mag_spec / (current_mag + 1e-9))
    # For a 2-D magnitude (freq, frames) this yields (freq, frames, 2): the same
    # factor for the real and the imaginary channel.
    scale = scale.repeat(2, 1, 1).permute(1, 2, 0)
    return spec * scale

def griffinlim(phase_spec, mag_spec, n_fft=1024, hop_length=256, win_length=1024, n_iter=30):
    r"""Refine a phase estimate for a fixed magnitude spectrogram with Griffin-Lim iterations.

    Implementation ported from
    *librosa* [:footcite:`brian_mcfee-proc-scipy-2015`], *A fast Griffin-Lim algorithm* [:footcite:`6701851`]
    and *Signal estimation from modified short-time Fourier transform* [:footcite:`1172092`],
    without the momentum term and returning the refined phase instead of a waveform.

    Each iteration combines ``mag_spec`` with the current phase, inverts the result
    with an ISTFT, re-analyses that waveform with an STFT and keeps only its phase.
    Complex-valued STFT tensors are used because recent PyTorch releases reject the
    legacy real ``(..., 2)`` layout in ``torch.stft`` / ``torch.istft``.

    Args:
        phase_spec (Tensor): Initial phase in radians, same shape as ``mag_spec``.
        mag_spec (Tensor): Magnitude spectrogram of shape ``(..., freq, frames)`` with
            ``freq == n_fft // 2 + 1``. It is kept fixed across iterations.
        n_fft (int): FFT size.
        hop_length (int): Hop between STFT frames.
        win_length (int): Length of the Hann analysis/synthesis window (``<= n_fft``).
        n_iter (int): Number of ISTFT/STFT rounds. With ``0`` the input phase is returned as is.

    Returns:
        torch.Tensor: The refined phase estimate (radians).
    """
    # torch.stft/istft require the window length to equal ``win_length``.
    window = torch.hann_window(win_length, dtype=mag_spec.dtype, device=phase_spec.device)

    for _ in range(n_iter):
        # Invert with the current phase estimate.
        complex_spec = torch.polar(mag_spec, phase_spec.to(mag_spec.dtype))
        inverse = torch.istft(complex_spec,
                              n_fft=n_fft,
                              hop_length=hop_length,
                              win_length=win_length,
                              window=window,
                              center=True, normalized=False, onesided=True)

        # Re-analyse the waveform and keep only the phase of the consistent spectrogram.
        rebuilt = torch.stft(
            inverse,
            n_fft=n_fft,
            hop_length=hop_length,
            win_length=win_length,
            window=window,
            center=True,
            pad_mode='reflect',
            normalized=False,
            onesided=True,
            return_complex=True,
        )
        phase_spec = torch.angle(rebuilt)

    return phase_spec

def _absolute_threshold_of_hearing(sampling_rate, n_fft, device):
    """Return a per-bin threshold-in-quiet curve (dB) for an ``n_fft``-point STFT.

    Bins before the first one whose Bark value exceeds 1 are set to ``-inf`` so
    they always pass a ``>=`` comparison. The remaining bins use
    ``3.64 f^-0.8 - 6.5 exp(-0.6 (f - 3.3)^2) + 1e-3 f^4`` (f in kHz), shifted
    down by 12 dB.
    """
    import librosa

    freqs = librosa.core.fft_frequencies(sr=sampling_rate, n_fft=n_fft)
    barks = 13 * np.arctan(0.00076 * freqs) + 3.5 * np.arctan(pow(freqs / 7500.0, 2))
    ath = np.full(len(barks), -np.inf, dtype=np.float64)
    bark_idx = np.argmax(barks > 1)
    khz = freqs[bark_idx:] * 0.001
    ath[bark_idx:] = (
            3.64 * pow(khz, -0.8)
            - 6.5 * np.exp(-0.6 * pow(khz - 3.3, 2))
            + 0.001 * pow(khz, 4)
            - 12
    )
    return torch.Tensor(ath).to(device)


def _tonal_peak_mask(psd, ath):
    """Select the spectro-temporal bins that get the stronger phase penalty.

    For every bin ``i`` whose probe bin ``i + 2`` is at or above ``ath``, the
    louder of bins ``i`` and ``i + 1`` is marked; the lower one wins a tie,
    matching ``torch.argmax``. The mark is made in the current frame. When the
    frame is not the last one, the same bin is also marked in the next frame and
    in the previous frame. "Previous" of frame 0 wraps to the last frame, which
    mirrors Python negative indexing in the original loop implementation.

    Args:
        psd: Log-magnitude spectrogram in dB, shape ``(freq, frames)``.
        ath: Per-bin threshold in dB, shape ``(freq,)``.

    Returns:
        Boolean tensor with the same shape and device as ``psd``.
    """
    n_bins, n_frames = psd.shape
    masks = torch.zeros_like(psd, dtype=torch.bool)
    if n_bins < 3:
        return masks
    ath = ath.to(psd.device)

    # above[i, t]: probe bin i + 2 reaches the threshold in frame t.
    above = psd[2:, :] >= ath[2:].unsqueeze(1)
    # rows[i, t]: louder of bins i and i + 1 (ties keep i, like argmax).
    rows = (torch.arange(n_bins - 2, device=psd.device).unsqueeze(1)
            + (psd[1:-1, :] > psd[:-2, :]).long())
    cols = torch.arange(n_frames, device=psd.device).unsqueeze(0).expand_as(rows)

    hits = torch.zeros_like(masks)
    hits[rows[above], cols[above]] = True

    masks |= hits
    if n_frames > 1:
        # Only frames t with t + 1 < n_frames spread to their neighbours.
        spread = hits[:, :-1]
        masks[:, 1:] |= spread          # frame t + 1
        masks[:, :-2] |= spread[:, 1:]  # frame t - 1 for t >= 1
        masks[:, -1] |= spread[:, 0]    # frame 0 - 1 wraps to the last frame
    return masks


def _attack_2nd_stage(y, y_mel, adv_tgt, id2label, device, attack_model, model_adv, sample_num, max_iterations):
    """Craft an adversarial waveform by optimising an additive offset on the STFT phase of ``y``.

    The magnitude spectrogram of ``y`` stays fixed; only a phase offset
    ``epsilon`` is optimised with Adam on sign-of-gradient updates. The loss is
    the attack model's adversarial loss plus ``alpha`` times a hinge penalty on
    ``cos^2(epsilon / 2)``. The penalty is stricter (0.98) on bins selected by
    ``_tonal_peak_mask`` and looser (0.9, weight 0.05) elsewhere. ``alpha``
    grows by 1.2x on every 10th step at which the attack succeeds.

    Args:
        y: Input waveform; presumably 1-D with values in [-1, 1] (values outside
            are only reported) -- TODO confirm against the dataset dump.
        y_mel: Unused.
        adv_tgt: Target token sequence passed to ``attack_model.get_adv_loss``.
        id2label: Token-id to label mapping forwarded to the attack model.
        device: Device on which the hearing-threshold tensor is created.
        attack_model: Object exposing ``get_adv_loss`` (``BaseOptimization`` in ``train``).
        model_adv: Victim model forwarded to ``get_adv_loss``.
        sample_num: Sample index, used only in progress messages.
        max_iterations: Number of optimisation steps.

    Returns:
        The detached waveform (clamped to [-1, 1], trimmed towards the input
        length) from the most recent step at which ``num_correct`` equalled the
        target length. If no step succeeds the process exits via ``sys.exit(0)``.
    """
    n_fft = 1024
    num_mels = 80
    sampling_rate = 16000
    hop_size = 512
    win_size = 1024
    fmin = 0
    fmax = None
    center = True
    origin_y = y

    if torch.min(y) < -1.:
        print('min value is ', torch.min(y))
    if torch.max(y) > 1.:
        print('max value is ', torch.max(y))

    # Check the same keys that are stored, so the caches are actually reused.
    # The mel basis is not used in this function; it is cached to keep the module cache populated.
    mel_key = str(fmax) + '_' + str(y.device)
    win_key = str(y.device)
    if mel_key not in mel_basis:
        mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
        mel_basis[mel_key] = torch.from_numpy(mel).to(y.device)
    if win_key not in hann_window:
        hann_window[win_key] = torch.hann_window(win_size).to(y.device)
    window = hann_window[win_key]

    half_pad = (n_fft - hop_size) // 2
    y = torch.nn.functional.pad(y.unsqueeze(0).unsqueeze(0), (half_pad, half_pad), mode='reflect')
    y = y.squeeze(0).squeeze(0)

    spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=window,
                      center=center, pad_mode='reflect', normalized=False, onesided=True)
    mag_spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))

    ath = _absolute_threshold_of_hearing(sampling_rate, n_fft, device)
    psd = 20 * torch.log10(mag_spec)
    masks = _tonal_peak_mask(psd, ath)

    print(masks[masks == 1].shape)
    print(mag_spec.shape)

    phase_spec = torch.atan2(spec[:, :, 1].data, spec[:, :, 0].data)

    def _synthesize(phase):
        # Waveform from the fixed magnitude and the given phase.
        recombined = torch.cat(
            [(mag_spec * torch.cos(phase)).unsqueeze(-1), (mag_spec * torch.sin(phase)).unsqueeze(-1)], dim=2)
        return torch.istft(recombined, n_fft, hop_length=hop_size, win_length=win_size, window=window,
                           center=center, normalized=False, onesided=True)

    # The istft output length differs from the input because of the padding
    # above; trim half of the difference from each side.
    y_new = _synthesize(phase_spec)
    pad = (y_new.shape[0] - origin_y.shape[0]) // 2

    epsilon = torch.nn.Parameter(torch.zeros_like(phase_spec).to(y.device))
    optim = torch.optim.Adam([epsilon], lr=5e-2, betas=(0.9, 0.999), eps=1e-9)

    alpha = 1e-4
    target_len = str(adv_tgt.squeeze().shape[0])
    adversarial_y = None
    for step in range(max_iterations):
        y_new = torch.clamp(_synthesize(phase_spec + epsilon), -1, 1)
        y_new = y_new[pad:y_new.shape[0] - pad]

        # First stage: adversarial loss of the victim model on the candidate waveform.
        mel_spec_adv = mel_spectrogram(y_new.squeeze().float(), 1024, 80, 16000, 256, 1024, 0, None)
        loss_1st_stage, num_correct, pred_txt, target_txt = attack_model.get_adv_loss(
            mel_spec_adv.unsqueeze(0).unsqueeze(0),
            torch.IntTensor([mel_spec_adv.shape[1]]),
            torch.FloatTensor([1.0]),
            adv_tgt, torch.IntTensor([len(adv_tgt)]),
            id2label, model_adv)

        # Second stage: (2 + 2cos(eps)) / 4 == cos^2(eps / 2), which is 1 for an untouched phase.
        power_phase = (2 + 2 * torch.cos(epsilon)) / 4
        loss_2st_stage = (torch.nn.functional.relu(0.98 - power_phase[masks]).sum()
                          + torch.nn.functional.relu(0.9 - power_phase[~masks]).sum() * 0.05)

        loss = loss_1st_stage + loss_2st_stage * alpha

        optim.zero_grad()
        loss.backward()
        if epsilon.grad is not None:
            # Zero NaNs, then keep only the sign of the gradient.
            epsilon.grad[torch.isnan(epsilon.grad)] = 0
            epsilon.grad.sign_()
        optim.step()

        print("Sample num: {:d}, Steps: {:d}, Loss adv: {:f}, Loss GD: {:f}, Accuracy: {:s} / {:s}".format(
            sample_num, step, loss_1st_stage.item(), loss_2st_stage.item(), str(num_correct), target_len))

        if str(num_correct) == target_len:
            # y_new was produced by the epsilon that was just evaluated, so it is
            # the verified adversarial waveform (unlike the post-step epsilon).
            adversarial_y = y_new.detach().clone()
            if step % 10 == 0:
                alpha *= 1.2

    if adversarial_y is None:
        # NOTE(review): terminates the whole process with exit status 0; callers
        # compare the result against None, so returning None may have been intended.
        import sys
        sys.exit(0)

    return adversarial_y


def train(rank, a, h):
    """Run the phase attack on ``a.num_attack`` pre-dumped samples on GPU ``rank``.

    For each index ``i`` the sample tuple is unpickled from
    ``adv_dataset/<i>.dump`` and attacked with ``_attack_2nd_stage``. When an
    adversarial waveform is returned, the following are written under
    ``a.checkpoint_path``: the ground-truth wav, the adversarial wav, and a
    ``label|target`` line. The mel L1 distance is also accumulated. The
    per-sample time (ms, from CUDA events) is printed and appended to
    ``time.txt``, and the summed L1 distance is printed at the end.

    NOTE(review): every rank iterates over all samples; there is no sharding by ``rank``.
    """
    if h.num_gpus > 1:
        init_process_group(backend=h.dist_config['dist_backend'], init_method=h.dist_config['dist_url'],
                           world_size=h.dist_config['world_size'] * h.num_gpus, rank=rank)

    torch.cuda.manual_seed(h.seed)
    device = torch.device('cuda:{:d}'.format(rank))

    model_adv, _, _, _, loaded_args, label2id, id2label = load_model_adv(a.adv_continue_from)
    attack_model = BaseOptimization(constant.args)

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    total_l1_loss = 0
    for i in range(a.num_attack):
        start.record()

        # NOTE: pickle.load can execute arbitrary code; only load trusted dumps.
        with open('adv_dataset/' + str(i) + '.dump', 'rb') as f:
            x, y, filename, y_mel, label, pitch, energy, adv_tgt = pickle.load(f)

        y = y.to(device, non_blocking=True)
        y_adv = _attack_2nd_stage(y.squeeze(), y_mel, adv_tgt, id2label, device,
                                  attack_model, model_adv, i, a.max_iterations)

        label_text = get_token2text(label, constant.args, id2label)
        target_text = get_token2text(adv_tgt, constant.args, id2label)

        os.makedirs(a.checkpoint_path, exist_ok=True)
        if y_adv is not None:
            with open(a.checkpoint_path + '/label.txt', 'a+') as f:
                f.write(label_text + "|" + target_text + "\n")
            torchaudio.save(
                a.checkpoint_path + '/' + str(i) + '_gt.wav',
                y.detach().cpu().squeeze().unsqueeze(0), sample_rate=h.sampling_rate)
            torchaudio.save(a.checkpoint_path + '/' + str(i) + '.wav',
                            y_adv.detach().cpu().squeeze().unsqueeze(0), sample_rate=h.sampling_rate)

            # Only meaningful (and only safe) when an adversarial waveform exists.
            mel_spec_adv = mel_spectrogram(y_adv.squeeze(), 1024, 80, 16000, 256, 1024, 0, None)
            total_l1_loss += F.l1_loss(y_mel.detach().cpu(), mel_spec_adv.detach().cpu()).item()

        end.record()
        # elapsed_time requires both events to have completed on the device.
        end.synchronize()
        elapsed_ms = start.elapsed_time(end)
        print(elapsed_ms)
        with open('time.txt', 'a+') as f:
            f.write('phasefool' + "|" + str(elapsed_ms) + "\n")

    print(total_l1_loss)


def main():
    """Parse CLI options, load the JSON config and launch ``train``.

    One process per visible GPU is spawned when more than one is available;
    otherwise ``train`` runs in-process with rank 0.
    """
    print('Initializing Training Process..')

    def _str2bool(value):
        # argparse's ``type=bool`` turns any non-empty string (even "False") into True.
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('1', 'true', 't', 'yes', 'y'):
            return True
        if lowered in ('0', 'false', 'f', 'no', 'n', ''):
            return False
        raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))

    parser = argparse.ArgumentParser()

    parser.add_argument('--group_name', default=None)
    parser.add_argument('--input_mels_dir', default='ft_dataset')
    parser.add_argument('--input_training_file', default='libri_train_clean_manifest.csv')
    parser.add_argument('--input_validation_file', default='libri_test_clean_manifest.csv')
    parser.add_argument('--checkpoint_path', default='result')
    parser.add_argument('--config', default='config_v1_libri.json')
    parser.add_argument('--fine_tuning', default=False, type=_str2bool)
    parser.add_argument('--adv_continue_from',
                        default='save/libri_TRFS/epoch_241.th')
    parser.add_argument('--num_attack', default=100, type=int)
    parser.add_argument('--max_iterations', default=500, type=int)

    a = parser.parse_args()

    with open(a.config) as f:
        json_config = json.load(f)

    h = AttrDict(json_config)
    h.sampling_rate = 16000
    build_env(a.config, 'config.json', a.checkpoint_path)

    torch.manual_seed(h.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(h.seed)
        h.num_gpus = torch.cuda.device_count()
        h.batch_size = int(h.batch_size / h.num_gpus)
        print('Batch size per GPU :', h.batch_size)

    if h.num_gpus > 1:
        mp.spawn(train, nprocs=h.num_gpus, args=(a, h,))
    else:
        train(0, a, h)


if __name__ == '__main__':
    # Run the attack only when executed as a script, not when imported.
    main()
